require(colorout)
require(sqldf)
require(arm)
require(rstan)

# Start from an empty workspace: save.image() at the end of this script writes
# every object in the workspace, so anything left over from an earlier session
# would otherwise end up in the saved .RData file.
rm(list=ls())

# Printing/session options for the run.
# NOTE(review): java.parameters is presumably read by a Java-backed package
# when its JVM starts; nothing in this file uses it directly -- confirm it is
# still needed.
options(
  digits=2,
  scipen=9,
  width=110,
  java.parameters="-Xrs"
)

################################################################################################################
# prep survey data

# Read the survey extract. data.table=FALSE asks fread() for a plain
# data.frame, which the dat[,i] column indexing used throughout relies on.
dat <- data.table::fread("data/yg_pamrp_surveydata.txt", data.table=FALSE)

# mean value replacement for missing continuous data (not the best, but not much here and not the point of this paper)
xs_num <- c("income_num", "prob_college", "prob_married", "age_num", "num_dem_votes", "num_rep_votes", 
  "pctwhite", "pctblack", "pcthispaniclatino", "pctmarriedwithchildren", "pctnoncitizenforeignborn", 
  "pctpublictranstowork", "marriedcouplehomeowners", "medianhhincome", "pctpublicassistance")

# Per-column mean and sd of the observed (non-missing) values, named by column.
# They are computed BEFORE imputation and reused later to standardise xs_num.
# vapply() is used instead of sapply() so the result is always a named numeric
# vector with exactly one value per column.
xs_num_mu <- vapply(c(xs_num, "dem2way2008"), function(i) mean(dat[,i], na.rm=TRUE), numeric(1))
xs_num_sd <- vapply(c(xs_num, "dem2way2008"), function(i) sd(dat[,i], na.rm=TRUE), numeric(1))

# Fill each column's missing entries with that column's observed mean.
# dem2way2008 is imputed here on its raw scale, before the logit transform.
for (i in names(xs_num_mu))
  dat[is.na(dat[,i]), i] <- xs_num_mu[[i]]

# turn categoricals into numeric, prepping for stan
xs_cat <- c("gender", "race", "income", "marital", "education", "age", "partyreg", 
  "dprimadv", "last2votes", "state", "region", "statetype")

# Sorted unique values of each categorical, kept so the integer codes below can
# be mapped back to their labels. lapply() (not sapply()) guarantees a named
# list even if every variable happened to have the same number of levels, in
# which case sapply() would silently return a matrix instead.
cat_lookups <- lapply(setNames(xs_cat, xs_cat), function(i) sort(unique(dat[,i])))

# Replace each categorical column by its 1-based level index. as.factor()
# orders levels by sorted unique value, so code j of a variable lines up with
# the j-th entry of that variable's cat_lookups element.
dat_cats <- dat[,xs_cat]
dat_cats[] <- lapply(dat_cats, function(column) as.numeric(as.factor(column)))

# make 2-way interactions between categoricals
# One row per unordered pair of variable names from xs_cat.
interactions <- as.data.frame(gtools::combinations(n=length(xs_cat), r=2, v=xs_cat), stringsAsFactors=FALSE)
colnames(interactions) <- c("x1", "x2")

# Exclude pairs in which BOTH variables come from state / statetype / region.
# NOTE(review): presumably because these three are overlapping groupings of the
# same unit, so crossing them adds nothing -- confirm.
drop <- interactions$x1 %in% c("state", "statetype", "region") & interactions$x2 %in% c("state", "statetype", "region")
interactions <- interactions[!drop,]
interactions$x <- paste0(interactions$x1, "__", interactions$x2)

# dat_interactions: one integer-coded column per interaction, same rows as dat.
# interaction_lookups: for each interaction, the sorted "a__b" labels whose
# positions are the integer codes. It is named by interaction (in addition to
# being addressable by position) so labels can be looked up by name, the same
# way cat_lookups is keyed by variable name.
dat_interactions <- as.data.frame(array(NA, c(nrow(dat), nrow(interactions))))
colnames(dat_interactions) <- interactions$x
interaction_lookups <- vector("list", nrow(interactions))
names(interaction_lookups) <- interactions$x
for (i in seq_len(nrow(interactions))) {
  # label each respondent with the pasted pair of raw category values ...
  dat_interactions[,i] <- paste0(dat[, interactions$x1[i]], "__", dat[, interactions$x2[i]])
  # ... remember the sorted label set ...
  interaction_lookups[[i]] <- sort(unique(dat_interactions[,i]))
  # ... and replace the labels by their 1-based position in that sorted set.
  dat_interactions[,i] <- as.numeric(as.factor(dat_interactions[,i]))
}

# put 2008 county vote share on the logit scale; center the other numeric xs and scale them by two standard deviations
# Log-odds of a proportion, with the input clipped to [0.005, 0.995] first so
# that values at or beyond 0 and 1 map to finite numbers (about -5.29 / +5.29).
# Vectorised over x; NA inputs stay NA.
logit <- function(x) {
  clipped <- pmin(0.995, pmax(0.005, x))
  odds_against <- 1 / clipped - 1
  -log(odds_against)
}
# Previous county vote share goes onto the (clipped) logit scale, in place.
# NOTE(review): assumes dem2way2008 is stored as a proportion in [0, 1], not a
# percentage -- confirm against the data file.
dat[["dem2way2008"]] <- logit(dat[["dem2way2008"]])

# Standardise every continuous predictor: subtract its mean and divide by two
# standard deviations, using the xs_num_mu / xs_num_sd values computed on the
# observed data before imputation.
dat_numeric <- dat[,xs_num]
dat_numeric[xs_num] <- lapply(xs_num, function(v) {
  (dat_numeric[[v]] - xs_num_mu[v]) / (2 * xs_num_sd[v])
})

# which xs will we use for varying slopes (on previous county vote)?
xs_slope <- c("partyreg", "statetype")

################################################################################################################
# compile stan model

# Every grouping index handed to Stan: the categorical main effects followed by
# their 2-way interactions.
cats_full <- c(xs_cat, interactions$x)

# The Stan program is generated as text because the number of grouping
# variables is only known at run time. For each name g in cats_full it emits
#   - data:       an index array g[n] and a level count n_g
#   - parameters: a vector alpha_g of length n_g with a normal(0,1) prior,
#                 multiplied by sigma_alpha[j] * sigma_sigma_alpha in the
#                 linear predictor (j = position of g in cats_full)
# and likewise beta_g / sigma_beta / sigma_sigma_beta for each g in xs_slope,
# which enter as varying slopes on z_dem2wayprev. sigma_alpha and sigma_beta
# get student_t(8,0,1) priors; alpha, beta, beta_z, sigma_sigma_alpha and
# sigma_sigma_beta have no sampling statement in the model block.
#
# The per-observation linear predictor is needed twice (model block and
# generated quantities). It is built once here and spliced into both places so
# the two copies cannot drift apart. The helper strings live inside local() so
# they do not become extra objects in the workspace that save.image() stores.
#
# NOTE(review): the program uses the older Stan array declaration form
# (e.g. "int y[n]"); confirm the installed rstan/Stan version still accepts it.
stan_code <- local({
  # "alpha_g[g[i_n]] * sigma_alpha[j] * sigma_sigma_alpha", summed over cats_full
  alpha_terms <- paste0(
    "alpha_", cats_full, "[", cats_full, "[i_n]] * sigma_alpha[", seq_along(cats_full),
    "] * sigma_sigma_alpha", collapse=" + ")
  # "beta_g[g[i_n]] * sigma_beta[j] * sigma_sigma_beta", summed over xs_slope
  beta_terms <- paste0(
    "beta_", xs_slope, "[", xs_slope, "[i_n]] * sigma_beta[", seq_along(xs_slope),
    "] * sigma_sigma_beta", collapse=" + ")
  # Loop filling yhat: intercepts + varying slope on previous vote + Z %*% beta_z
  yhat_loop <- paste0(
"    for (i_n in 1:n) {
      yhat[i_n] = 
        alpha + ", alpha_terms, " + 
        z_dem2wayprev[i_n] * (beta + ", beta_terms, ");
      for (i_k in 1:k)
        yhat[i_n] = yhat[i_n] + Z[i_n, i_k] * beta_z[i_k];
    }")

  paste0("
  data {
    int<lower=0> n;
    int<lower=0> k;
    int<lower=0, upper=1> y[n];
    vector[n] z_dem2wayprev; 
    matrix[n, k] Z;
    ", paste0("int<lower=0> ", cats_full, "[n]", collapse=";\n"), "; 
    ", paste0("int<lower=0> n_", cats_full, collapse=";\n"), "; 
  }
  parameters {
    vector[k] beta_z;

    // varying intercepts for all categoricals and 2-way interactions
    real alpha;
    ", paste0("vector[n_", cats_full, "] alpha_", cats_full, collapse=";\n"), "; 

    // varying slopes on previous county vote share, for a smaller set
    real beta;
    ", paste0("vector[n_", xs_slope, "] beta_", xs_slope, collapse=";\n"), "; 

    // variance hyper-parameters
    real<lower=0> sigma_alpha[", length(cats_full), "];
    real<lower=0> sigma_sigma_alpha;
    real<lower=0> sigma_beta[", length(xs_slope), "];
    real<lower=0> sigma_sigma_beta;
  }
  model {
    vector[n] yhat;

    ", paste0("alpha_", cats_full, " ~ normal(0,1)", collapse=";\n"), "; 
    ", paste0("beta_", xs_slope, " ~ normal(0,1)", collapse=";\n"), "; 
    sigma_alpha ~ student_t(8,0,1);
    sigma_beta ~ student_t(8,0,1);

", yhat_loop, "

    y ~ bernoulli_logit(yhat);
  }
  generated quantities {
    vector[n] yhat;
", yhat_loop, "
  }
")
})

################################################################################################################
# put input data into stan_data format and run it

# Named list matching the Stan data block, assembled in one go:
#   1. sizes, outcome, logit previous vote share and the standardised matrix Z
#   2. one integer index vector per categorical, then per interaction
#   3. the number of levels ("n_<name>") for each of those, taken as the
#      largest code present in the corresponding index vector
stan_data <- c(
  list(
    n=nrow(dat),
    k=ncol(dat_numeric),
    y=dat$y,
    z_dem2wayprev=dat$dem2way2008,
    Z=as.matrix(dat_numeric)),
  as.list(dat_cats[xs_cat]),
  as.list(dat_interactions[interactions$x]),
  setNames(
    lapply(xs_cat, function(v) max(dat_cats[,v])),
    paste0("n_", xs_cat)),
  setNames(
    lapply(interactions$x, function(v) max(dat_interactions[,v])),
    paste0("n_", interactions$x)))

# THIS COMMENTED CODE WAS USED TO TEST THE MODEL USING ADVI
# ADVI IS FASTER THAN MCMC BUT PRODUCES LESS RELIABLE ESTIMATES
# ONCE THE CODE WAS WORKING, I SWITCHED TO MCMC FOR FINAL ESTIMATES

# sm <- stan_model(model_code=stan_code)
# stan_seed <- round(runif(1, 0, 20000))
# M <- vb(sm, data=stan_data, output_samples=1000, seed=stan_seed, iter=20000)

# require(foreach)
# require(doMC)
# sm <- stan_model(model_code=stan_code)
# registerDoMC(5)
# st <- system.time(M <- foreach(i_M = 1:5) %dopar% {
#   stan_seed <- round(runif(1, 0, 20000))
#   M_out <- capture.output(M <- vb(sm, data=stan_data, output_samples=1000, seed=stan_seed, iter=20000))
#   return(list(stan_seed=stan_seed, M=M, M_out=M_out))
# })
# save.image("yg_pa_mrp_advi.RData")

# Fix R's RNG state before sampling; no seed argument is passed to stan() below.
set.seed(12345)

# Compile and sample: 6 chains of 1000 iterations each, spread over 6 cores.
# The fit is kept in M and the elapsed wall/CPU time in st.
st <- system.time({
  M <- stan(
    model_code=stan_code,
    data=stan_data,
    chains=6,
    cores=6,
    iter=1000,
    verbose=TRUE)
})

# Persist the whole workspace (data prep objects, lookups, fit, timing).
save.image("yg_pa_mrp_mcmc.RData")
